#!/usr/bin/env python
# -*- coding: utf-8 -*-


import sys
import csv
import operator
import ast
import pickle
from math import log
from src.manager.log_manager import LogManager
from src.constant.file_and_path_constant import FileAndPathConstant

sys.path.append(FileAndPathConstant.System_Drive + '/github-repository/hades/dev-project/jormungandr')
Logger = LogManager.get_logger(__name__)


class Tree2Test:
    """ID3 decision tree for the UCI Car Evaluation dataset.

    Builds a tree with :meth:`create_tree`, classifies samples with
    :meth:`classify`, and persists/restores the tree with
    :meth:`write_tree` / :meth:`read_tree`.
    """

    # Path of the data file (comma-separated CSV, last column is the label).
    File_Path = FileAndPathConstant.System_Drive + "/github-repository/dataset/CarEvaluation/car.data"

    # Mode used when opening the data file.
    Open_Mode = "rt"

    # Every possible value of each feature column, in dataset column order.
    All_Column_Value_Dict = {
        # Buying_Class_List
        "Buying_Class_List": ["vhigh", "high", "med", "low"],
        # Maint_Class_List
        "Maint_Class_List": ["vhigh", "high", "med", "low"],
        # Doors_Class_List
        "Doors_Class_List": ["2", "3", "4", "5more"],
        # Persons_Class_List
        "Persons_Class_List": ["2", "4", "more"],
        # Lug_Boot_Class_List
        "Lug_Boot_Class_List": ["small", "med", "big"],
        # Safety_Class_List
        "Safety_Class_List": ["low", "med", "high"]
    }

    # Feature (column) names, in the same order as the dataset columns.
    Column_Value_List = ["Buying_Class_List", "Maint_Class_List", "Doors_Class_List",
                         "Persons_Class_List", "Lug_Boot_Class_List", "Safety_Class_List"]

    # All class labels that can appear in the last column.
    Label_List = ["unacc", "acc", "good", "vgood"]

    # Where the pickled decision tree is stored.
    Storage_Path = FileAndPathConstant.System_Drive + '/github-repository/hades/dev-project/jormungandr/test/tree/classifierStorage.txt'

    def __init__(self):
        pass

    def read_file(self, file_path, open_mode):
        """Read a CSV file and return its rows as a list of lists.

        :param file_path: path of the CSV file
        :param open_mode: mode passed to ``open()``, e.g. ``"rt"``
        :return: list of non-empty rows
        """
        Logger.info("读文件，以二维数组形式返回")

        with open(file_path, open_mode) as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            # The file typically ends with a blank line, which the csv
            # reader yields as [] — filter empty rows out.
            return [row for row in reader if row]

    def calculate_shannon_entropy(self, dataset):
        """Compute the Shannon entropy of the label (last) column.

        :param dataset: list of rows; the last element of each row is the label
        :return: entropy in bits; 0.0 for an empty dataset
        """
        Logger.info("计算香浓熵")

        total_row_count = len(dataset)
        # An empty dataset carries no information; this also prevents the
        # ZeroDivisionError the original hit on empty sub-datasets.
        if total_row_count == 0:
            return 0.0

        # Occurrences of each label.
        class_count = self.calculate_label_count(dataset)

        # H = -sum(p * log2(p)) over all labels.
        shannon_entropy = 0.0
        for count in class_count.values():
            probability = count / total_row_count
            shannon_entropy -= probability * log(probability, 2)

        return shannon_entropy

    def split_dataset(self, dataset, characteristic_value_index, characteristic_value):
        """Select the rows whose given column equals *characteristic_value*,
        returning them with that column removed.

        :param dataset: list of rows
        :param characteristic_value_index: index of the column to split on
        :param characteristic_value: value the column must match
        :return: the matching rows, each without the split column
        """
        Logger.info("划分数据集")

        sub_dataset = []
        for row in dataset:
            # Guard explicitly against short rows instead of catching
            # IndexError around the whole loop.
            if characteristic_value_index >= len(row):
                continue
            if row[characteristic_value_index] == characteristic_value:
                # New list: the original row must stay untouched.
                sub_dataset.append(row[:characteristic_value_index]
                                   + row[characteristic_value_index + 1:])

        return sub_dataset

    def calculate_label_count(self, dataset):
        """Count the occurrences of each label (last column) in *dataset*.

        :param dataset: list of rows
        :return: dict mapping label -> occurrence count
        """
        Logger.info("计算每种标签的数量")

        class_count = dict()
        for row in dataset:
            label = row[-1]
            class_count[label] = class_count.get(label, 0) + 1

        return class_count

    def choose_best_feature_to_split(self, dataset, all_column_value_dict):
        """Choose the best column to split on (ID3 criterion).

        The best column is the one with the smallest weighted conditional
        entropy, which is equivalent to the largest information gain
        ``H(dataset) - H(dataset | column)``.

        :param dataset: list of rows, last column is the label
        :param all_column_value_dict: remaining feature name -> possible values,
            in the same order as the dataset columns
        :return: (best column index, best column name, its conditional entropy);
            index and name are ``None`` when no column strictly reduces entropy
        """
        Logger.info("选择最好的数据集的划分方式，返回特征值的索引")

        all_shannon_entropy_dict = dict()
        for index, column_name in enumerate(all_column_value_dict):
            # Weighted (conditional) entropy after splitting on this column.
            column_shannon_entropy = 0.0
            for column_value in all_column_value_dict[column_name]:
                sub_dataset = self.split_dataset(dataset, index, column_value)
                sub_dataset_shannon_entropy = self.calculate_shannon_entropy(sub_dataset)
                # Weight by |sub| / |dataset| because entropy is an expectation.
                # NOTE: the original negated this sum and then took the minimum,
                # which selected the WORST feature; keep it positive and minimize.
                column_shannon_entropy += (len(sub_dataset) / len(dataset)
                                           * sub_dataset_shannon_entropy)
            all_shannon_entropy_dict[column_name] = column_shannon_entropy

        # Log the conditional entropy of every candidate split.
        for index, column_name in enumerate(all_shannon_entropy_dict):
            Logger.info("特征值索引【%d】，特征值名称【%s】，香浓熵为【%f】"
                        % (index, column_name, all_shannon_entropy_dict[column_name]))

        # Minimize the conditional entropy; start from H(dataset) so only
        # columns with a strictly positive information gain are accepted.
        best_feature_name = None
        best_feature_index = None
        min_shannon_entropy = self.calculate_shannon_entropy(dataset)
        for index, column_name in enumerate(all_shannon_entropy_dict):
            if all_shannon_entropy_dict[column_name] < min_shannon_entropy:
                min_shannon_entropy = all_shannon_entropy_dict[column_name]
                best_feature_name = column_name
                best_feature_index = index

        # Only log when a winner exists ("%d" % None would raise TypeError).
        if best_feature_index is not None:
            Logger.info("最好的数据集的划分方式：特征值索引【%d】，特征值名称【%s】，香浓熵为【%f】"
                        % (best_feature_index, best_feature_name, min_shannon_entropy))

        return best_feature_index, best_feature_name, min_shannon_entropy

    def count_and_sort_label_list(self, label_list):
        """Count how often each label appears, sorted descending by count.

        :param label_list: list of labels
        :return: list of (label, count) tuples, most frequent first
        """
        Logger.info("计算每个标签出现的次数，降序排列")

        label_count_dict = dict()
        for label in label_list:
            label_count_dict[label] = label_count_dict.get(label, 0) + 1

        # Descending by occurrence count.
        return sorted(label_count_dict.items(), key=operator.itemgetter(1), reverse=True)

    def create_tree(self, dataset, all_column_value_dict):
        """Recursively build an ID3 decision tree.

        The caller's *all_column_value_dict* is NOT modified (the original
        implementation deleted keys from it as a side effect).

        :param dataset: list of rows, last column is the label
        :param all_column_value_dict: remaining feature name -> possible values
        :return: a nested dict tree, or a label string for a leaf
        """
        Logger.info("创建树")

        class_list = [example[-1] for example in dataset]

        # Stop 1: every row carries the same label -> leaf with that label.
        if class_list.count(class_list[0]) == len(class_list):
            return class_list[0]

        # Stop 2: all features consumed -> leaf with the majority label.
        # (The original computed the majority vote but forgot to return it.)
        if len(dataset[0]) == 1:
            return self.count_and_sort_label_list(class_list)[0][0]

        best_feature_index, best_feature_name, min_shannon_entropy = \
            self.choose_best_feature_to_split(dataset, all_column_value_dict)
        # No column strictly reduces entropy -> majority-vote leaf instead of
        # a None leaf that classify() cannot interpret.
        if best_feature_index is None or best_feature_name is None:
            return self.count_and_sort_label_list(class_list)[0][0]

        # Features left for the subtrees: a fresh dict without the column we
        # just consumed, leaving the caller's dict intact.
        remaining_column_value_dict = {name: values
                                       for name, values in all_column_value_dict.items()
                                       if name != best_feature_name}

        # Distinct values the chosen column actually takes in this dataset.
        unique_best_feature_value_set = {example[best_feature_index] for example in dataset}

        # Recurse on each branch; every recursive call gets its own copy so
        # sibling branches cannot affect each other.
        my_tree = {best_feature_name: {}}
        for best_feature_value in unique_best_feature_value_set:
            sub_dataset = self.split_dataset(dataset, best_feature_index, best_feature_value)
            my_tree[best_feature_name][best_feature_value] = self.create_tree(
                sub_dataset, dict(remaining_column_value_dict))

        return my_tree

    def classify(self, input_tree, feature_list, test_list):
        """Classify *test_list* by walking the decision tree.

        :param input_tree: tree produced by :meth:`create_tree`
        :param feature_list: column names, in dataset column order
        :param test_list: one sample's feature values, same order
        :return: the predicted label, or ``None`` when the tree has no branch
            for one of the sample's values
        """
        Logger.info("使用决策树分类")

        # Initialized so a missing branch yields None instead of an
        # UnboundLocalError.
        label = None
        for feature_name, second_dict in input_tree.items():
            feature_index = feature_list.index(feature_name)
            test_value = test_list[feature_index]
            for branch_value, subtree in second_dict.items():
                if test_value == branch_value:
                    if isinstance(subtree, dict):
                        # Internal node: keep descending.
                        label = self.classify(subtree, feature_list, test_list)
                    else:
                        # Leaf node: the label itself.
                        label = subtree
            # A tree node has exactly one root key; stop after it.
            return label
        return label

    def write_tree(self, input_tree, filename):
        """Pickle *input_tree* to *filename*.

        :param input_tree: the decision tree to persist
        :param filename: destination file path
        """
        with open(filename, 'wb') as fw:
            pickle.dump(input_tree, fw)

    def read_tree(self, filename):
        """Load a pickled decision tree from *filename*.

        NOTE(review): only unpickle files this program wrote itself —
        pickle.load on untrusted data can execute arbitrary code.

        :param filename: source file path
        :return: the unpickled tree
        """
        # 'rb' instead of the original 'rb+': reading needs no write access.
        with open(filename, 'rb') as fr:
            return pickle.load(fr)


if __name__ == "__main__":
    tree2_test = Tree2Test()

    # Read the CSV data file into a list of rows.
    dataset = tree2_test.read_file(Tree2Test.File_Path, Tree2Test.Open_Mode)

    # Build the tree from a shallow copy: create_tree historically deleted
    # keys from the dict it received, which destroyed the class attribute.
    my_tree = tree2_test.create_tree(dataset, dict(Tree2Test.All_Column_Value_Dict))
    Logger.info("决策树为【%s】" % str(my_tree))

    # Classify one sample with the freshly built tree.
    test_list = ["low", "med", "4", "more", "med", "high"]
    label = tree2_test.classify(my_tree, Tree2Test.Column_Value_List, test_list)
    Logger.info("类型为【%s】" % label)

    # Persist the tree, then read it back to verify round-tripping.
    tree2_test.write_tree(my_tree, Tree2Test.Storage_Path)
    my_tree_str = tree2_test.read_tree(Tree2Test.Storage_Path)
    Logger.info("决策树为【%s】" % my_tree_str)
